package com.github.xzb617.client.flow.adapte.boot;

import com.github.xzb617.client.flow.cache.CounterCache;
import com.github.xzb617.client.flow.core.*;
import com.github.xzb617.client.flow.core.hotparam.HotParamFlowLimiter;
import com.github.xzb617.client.flow.core.rules.RulesFlowLimiter;
import com.github.xzb617.client.flow.core.statics.Statics;
import com.github.xzb617.client.flow.core.system.SystemProtectedFlowLimiter;
import com.github.xzb617.client.flow.fallback.FlowFallback;
import com.github.xzb617.client.flow.fallback.FlowResponse;
import com.github.xzb617.client.flow.utils.DateUtil;
import com.github.xzb617.client.flow.utils.ResponseUtil;
import com.github.xzb617.client.flow.utils.WebUtil;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Spring servlet filter that applies flow control (rate limiting) to incoming requests.
 *
 * <p>For each request the filter records the request in {@link Statics}, assembles a
 * {@link FlowContext} (current QPS, client IP, request URI and request parameters) and runs it
 * through a {@link FlowLimiterChain} built from the system-protection, hot-parameter and
 * rule-based limiters. If a {@link FlowException} is thrown, the configured {@link FlowFallback}
 * produces a response that is written back as JSON and the request is not forwarded. Otherwise
 * the request is counted as passed and handed on to the remaining filter chain.
 *
 * @author xzb617
 */
public class FlowControlBootFilter extends OncePerRequestFilter {

    private final FlowProperties flowProperties;
    private final CounterCache counterCache;
    private final FlowFallback flowFallback;

    /**
     * Creates the filter with its collaborators.
     *
     * @param flowProperties flow-control configuration handed to every limiter chain execution
     * @param counterCache   counter storage handed to every limiter chain execution
     * @param flowFallback   builds the response that is written when a request is limited
     */
    public FlowControlBootFilter(FlowProperties flowProperties, CounterCache counterCache, FlowFallback flowFallback) {
        this.flowFallback = flowFallback;
        this.counterCache = counterCache;
        this.flowProperties = flowProperties;
    }

    /**
     * Runs the request through the limiter chain and either rejects it via the fallback or
     * forwards it down the filter chain.
     *
     * @param request     the incoming HTTP request
     * @param response    the HTTP response; receives the fallback JSON when the request is limited
     * @param filterChain the remaining servlet filter chain, invoked only for requests that pass
     * @throws ServletException propagated from downstream filters or the fallback write
     * @throws IOException      propagated from downstream filters or the fallback write
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        // Count the request toward the total first, so the QPS snapshot below includes it.
        final long now = DateUtil.getCurrentTimestamp();
        Statics.addTotal(now);

        final FlowContext context = this.createFlowContext(request, now);
        final FlowLimiterChain limiterChain = this.buildLimiterChain();

        try {
            // The chain is also passed as its own last argument.
            // NOTE(review): presumably so limiters can continue the chain; confirm in FlowLimiterChain.
            limiterChain.execute(context, this.flowProperties, this.counterCache, limiterChain);
        } catch (FlowException blocked) {
            // The request was limited: answer with the fallback response and do not forward it.
            this.writeFallback(response, blocked);
            return;
        }

        // No FlowException escaped the chain, so the request counts as passed.
        Statics.addPass(now);
        filterChain.doFilter(request, response);
    }

    /**
     * Assembles the {@link FlowContext} describing one request.
     *
     * @param request   the incoming HTTP request
     * @param timestamp the timestamp at which the request was counted
     * @return a context holding the current QPS, client IP, request URI and request parameters
     */
    private FlowContext createFlowContext(HttpServletRequest request, long timestamp) {
        long currentQps = Statics.getCurrentQPS(timestamp);
        String clientIp = WebUtil.getClientIpAddr(request);
        String requestUri = request.getRequestURI();
        return new FlowContext(currentQps, clientIp, requestUri, request.getParameterMap());
    }

    /**
     * Writes the fallback response for a limited request as JSON.
     *
     * @param response the HTTP response to write to
     * @param cause    the exception raised by the limiter chain
     * @throws ServletException if writing the response fails
     * @throws IOException      if writing the response fails
     */
    private void writeFallback(HttpServletResponse response, FlowException cause) throws ServletException, IOException {
        FlowResponse fallbackResponse = this.flowFallback.fallback(cause);
        ResponseUtil.writeJSON(response, fallbackResponse.getResponseStatus(), fallbackResponse.getResponseContent());
    }

    /**
     * Builds a new limiter chain with new limiter instances; called once per request.
     *
     * <p>Limiters are registered in this order: system protection, hot parameters, rules.
     * NOTE(review): presumably this is also their execution order; confirm in FlowLimiterChain.
     *
     * @return a newly created {@link FlowLimiterChain}
     */
    private FlowLimiterChain buildLimiterChain() {
        List<FlowLimiter> limiters = new ArrayList<>(3);
        limiters.add(new SystemProtectedFlowLimiter());
        limiters.add(new HotParamFlowLimiter());
        limiters.add(new RulesFlowLimiter());
        return new FlowLimiterChain(limiters);
    }

}
